{
    "cells": [
        {
            "cell_type": "raw",
            "metadata": {},
            "source": [
                "---\n",
                "title: Mandelbrot in Mojo with Python plots\n",
                "description: Learn how to write high-performance Mojo code and import Python packages.\n",
                "website:\n",
                "  open-graph:\n",
                "    image: /static/images/mojo-social-card.png\n",
                "  twitter-card:\n",
                "    image: /static/images/mojo-social-card.png\n",
                "---"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "[//]: # REMOVE_FOR_WEBSITE\n",
                "*Copyright 2023 Modular, Inc: Licensed under the Apache License v2.0 with LLVM Exceptions.*"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "[//]: # REMOVE_FOR_WEBSITE\n",
                "# Mandelbrot in Mojo with Python plots\n"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Not only is Mojo great for writing high-performance code, but it also allows us to leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it's good at, especially GUIs, without sacrificing performance in critical code. Let's take the classic Mandelbrot set algorithm and implement it in Mojo.\n",
                "\n",
                "This tutorial shows two aspects of Mojo. First, it shows that Mojo can be used to develop fast programs for irregular applications, where the amount of work varies from one element to the next (here, from pixel to pixel). Second, it shows how we can leverage Python to visualize the results."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "#|code-fold: true\n",
                "import benchmark\n",
                "from math import iota\n",
                "from sys import num_physical_cores\n",
                "from algorithm import parallelize, vectorize\n",
                "from complex import ComplexFloat64, ComplexSIMD\n",
                "from python import Python\n",
                "\n",
                "alias float_type = DType.float32\n",
                "alias int_type = DType.int32\n",
                "alias simd_width = 2 * simdwidthof[float_type]()\n",
                "alias unit = benchmark.Unit.ms"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "First, set some parameters. You can try changing these to see different results:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "alias width = 960\n",
                "alias height = 960\n",
                "alias MAX_ITERS = 200\n",
                "\n",
                "alias min_x = -2.0\n",
                "alias max_x = 0.6\n",
                "alias min_y = -1.5\n",
                "alias max_y = 1.5"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Here we define a simple `Matrix` struct:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "@value\n",
                "struct Matrix[type: DType, rows: Int, cols: Int]:\n",
                "    var data: DTypePointer[type]\n",
                "\n",
                "    fn __init__(inout self):\n",
                "        self.data = DTypePointer[type].alloc(rows * cols)\n",
                "\n",
                "    fn __getitem__(self, row: Int, col: Int) -> Scalar[type]:\n",
                "        return self.data.load(row * cols + col)\n",
                "\n",
                "    fn store[width: Int = 1](self, row: Int, col: Int, val: SIMD[type, width]):\n",
                "        self.data.store[width=width](row * cols + col, val)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The core [Mandelbrot](https://en.wikipedia.org/wiki/Mandelbrot_set) algorithm involves computing an iterative complex function for each pixel until it \"escapes\" the complex circle of radius 2, counting the number of iterations to escape:\n",
                "\n",
                "$$z_{i+1} = z_i^2 + c$$"
            ]
        },
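        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "As a sketch of what one iteration looks like in real arithmetic, write $z_i = x_i + i y_i$ and $c = c_x + i c_y$ (this component notation is introduced here for illustration). Expanding $z_i^2 + c$ and separating real and imaginary parts gives:\n",
                "\n",
                "$$x_{i+1} = x_i^2 - y_i^2 + c_x$$\n",
                "\n",
                "$$y_{i+1} = 2 x_i y_i + c_y$$\n",
                "\n",
                "A point has left the circle of radius 2 when $|z_i| > 2$, which is equivalent to $x_i^2 + y_i^2 > 4$. Comparing the squared norm against 4 avoids computing a square root, which is why the kernel below tests `squared_norm() > 4`."
            ]
        },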
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Compute the number of steps to escape.\n",
                "def mandelbrot_kernel(c: ComplexFloat64) -> Int:\n",
                "    z = c\n",
                "    for i in range(MAX_ITERS):\n",
                "        z = z * z + c\n",
                "        if z.squared_norm() > 4:\n",
                "            return i\n",
                "    return MAX_ITERS\n",
                "\n",
                "\n",
                "def compute_mandelbrot() -> Matrix[float_type, height, width]:\n",
                "    # Create a matrix. Each element of the matrix corresponds to a pixel.\n",
                "    matrix = Matrix[float_type, height, width]()\n",
                "\n",
                "    dx = (max_x - min_x) / width\n",
                "    dy = (max_y - min_y) / height\n",
                "\n",
                "    y = min_y\n",
                "    for row in range(height):\n",
                "        x = min_x\n",
                "        for col in range(width):\n",
                "            matrix.store(row, col, mandelbrot_kernel(ComplexFloat64(x, y)))\n",
                "            x += dx\n",
                "        y += dy\n",
                "    return matrix"
            ]
        },
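        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The loops above walk over the image in fixed steps. As a worked example with the default parameters in this notebook (the symbols $x_{min}$, $x_{max}$, $y_{min}$, $y_{max}$, $W$, and $H$ are shorthand introduced here for `min_x`, `max_x`, `min_y`, `max_y`, `width`, and `height`), the spacing between neighboring pixels is:\n",
                "\n",
                "$$dx = (x_{max} - x_{min}) / W = (0.6 - (-2.0)) / 960 = 2.6 / 960 \\approx 0.00271$$\n",
                "\n",
                "$$dy = (y_{max} - y_{min}) / H = (1.5 - (-1.5)) / 960 = 3.0 / 960 = 0.003125$$\n",
                "\n",
                "The pixel at row $r$ and column $k$ therefore corresponds, up to floating-point rounding from the repeated `x += dx` and `y += dy` additions, to the point:\n",
                "\n",
                "$$c = (x_{min} + k \\cdot dx) + i (y_{min} + r \\cdot dy)$$"
            ]
        },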
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it, we can use Python's `matplotlib` directly from Mojo!\n",
                "\n",
                "First install the required libraries:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "%%python\n",
                "from importlib.util import find_spec\n",
                "import shutil\n",
                "import subprocess\n",
                "\n",
                "fix = \"\"\"\n",
                "-------------------------------------------------------------------------\n",
                "fix following the steps here:\n",
                "    https://github.com/modularml/mojo/issues/1085#issuecomment-1771403719\n",
                "-------------------------------------------------------------------------\n",
                "\"\"\"\n",
                "\n",
                "def install_if_missing(name: str):\n",
                "    if find_spec(name):\n",
                "        return\n",
                "    print(f\"{name} not found, installing...\")\n",
                "    try:\n",
                "        # Find an interpreter on PATH; raise a real exception if none exists.\n",
                "        if shutil.which('python3'): python = \"python3\"\n",
                "        elif shutil.which('python'): python = \"python\"\n",
                "        else: raise RuntimeError(\"python not on path\" + fix)\n",
                "        subprocess.check_call([python, \"-m\", \"pip\", \"install\", name])\n",
                "    except Exception as err:\n",
                "        raise ImportError(f\"{name} not found\" + fix) from err\n",
                "\n",
                "install_if_missing(\"numpy\")\n",
                "install_if_missing(\"matplotlib\")"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "def show_plot[type: DType](matrix: Matrix[type, height, width]):\n",
                "    alias scale = 10\n",
                "    alias dpi = 64\n",
                "\n",
                "    np = Python.import_module(\"numpy\")\n",
                "    plt = Python.import_module(\"matplotlib.pyplot\")\n",
                "    colors = Python.import_module(\"matplotlib.colors\")\n",
                "\n",
                "    numpy_array = np.zeros((height, width), np.float64)\n",
                "\n",
                "    for row in range(height):\n",
                "        for col in range(width):\n",
                "            numpy_array.itemset((row, col), matrix[row, col])\n",
                "\n",
                "    fig = plt.figure(1, [scale, scale * height // width], dpi)\n",
                "    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], False, 1)\n",
                "    light = colors.LightSource(315, 10, 0, 1, 1, 0)\n",
                "\n",
                "    image = light.shade(numpy_array, plt.cm.hot, colors.PowerNorm(0.3), \"hsv\", 0, 0, 1.5)\n",
                "    plt.imshow(image)\n",
                "    plt.axis(\"off\")\n",
                "    plt.show()\n",
                "\n",
                "show_plot(compute_mandelbrot())"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Vectorizing Mandelbrot\n",
                "We showed a naive implementation of the Mandelbrot algorithm, but there are two things we can do to speed it up. We can stop the loop early once a pixel is known to have escaped, and we can leverage Mojo's access to hardware by vectorizing the loop, computing multiple pixels simultaneously. To do that, we will use the `vectorize` higher-order generator.\n",
                "\n",
                "We start by defining our main iteration loop in a vectorized fashion:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "fn mandelbrot_kernel_SIMD[\n",
                "    simd_width: Int\n",
                "](c: ComplexSIMD[float_type, simd_width]) -> SIMD[int_type, simd_width]:\n",
                "    \"\"\"A vectorized implementation of the inner mandelbrot computation.\"\"\"\n",
                "    var cx = c.re\n",
                "    var cy = c.im\n",
                "    var x = SIMD[float_type, simd_width](0)\n",
                "    var y = SIMD[float_type, simd_width](0)\n",
                "    var y2 = SIMD[float_type, simd_width](0)\n",
                "    var iters = SIMD[int_type, simd_width](0)\n",
                "\n",
                "    var t: SIMD[DType.bool, simd_width] = True\n",
                "    for _ in range(MAX_ITERS):\n",
                "        if not any(t):\n",
                "            break\n",
                "        y2 = y * y\n",
                "        y = x.fma(y + y, cy)\n",
                "        t = x.fma(x, y2) <= 4\n",
                "        x = x.fma(x, cx - y2)\n",
                "        iters = t.select(iters + 1, iters)\n",
                "    return iters"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The above function is parameterized on `simd_width` and processes `simd_width` pixels at a time. It exits the loop only once every pixel in the vector has escaped (or `MAX_ITERS` is reached). We keep the same row-by-row structure as the naive version, but this time we vectorize within each row. The `vectorize` generator makes this a simple function call. The benchmark can run either vectorized only, or vectorized and parallelized."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "fn run_mandelbrot(parallel: Bool) raises -> Float64:\n",
                "    var matrix = Matrix[int_type, height, width]()\n",
                "\n",
                "    @parameter\n",
                "    fn worker(row: Int):\n",
                "        var scale_x = (max_x - min_x) / width\n",
                "        var scale_y = (max_y - min_y) / height\n",
                "\n",
                "        @parameter\n",
                "        fn compute_vector[simd_width: Int](col: Int):\n",
                "            \"\"\"Each time we operate on a `simd_width` vector of pixels.\"\"\"\n",
                "            var cx = min_x + (col + iota[float_type, simd_width]()) * scale_x\n",
                "            var cy = min_y + row * scale_y\n",
                "            var c = ComplexSIMD[float_type, simd_width](cx, cy)\n",
                "            matrix.store(row, col, mandelbrot_kernel_SIMD[simd_width](c))\n",
                "\n",
                "        # Vectorize the call to compute_vector, where each call gets a chunk of pixels.\n",
                "        vectorize[compute_vector, simd_width](width)\n",
                "\n",
                "    @parameter\n",
                "    fn bench():\n",
                "        for row in range(height):\n",
                "            worker(row)\n",
                "\n",
                "    @parameter\n",
                "    fn bench_parallel():\n",
                "        parallelize[worker](height, height)\n",
                "\n",
                "    var time: Float64 = 0\n",
                "    if parallel:\n",
                "        time = benchmark.run[bench_parallel](max_runtime_secs=0.5).mean(unit)\n",
                "    else:\n",
                "        time = benchmark.run[bench](max_runtime_secs=0.5).mean(unit)\n",
                "\n",
                "    show_plot(matrix)\n",
                "    matrix.data.free()\n",
                "    return time\n",
                "\n",
                "vectorized = run_mandelbrot(parallel=False)\n",
                "print(\"Vectorized:\", vectorized, unit)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Parallelizing Mandelbrot\n",
                "While the vectorized implementation above is efficient, we can get better performance by parallelizing across the rows. This again is simple in Mojo using the `parallelize` higher-order function:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "parallelized = run_mandelbrot(parallel=True)\n",
                "print(\"Parallelized:\", parallelized, unit)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Benchmarking\n",
                "\n",
                "In this section, we compare the vectorized speed to the parallelized speed."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "print(\"Number of physical cores:\", num_physical_cores())\n",
                "print(\"Vectorized:\", vectorized, unit)\n",
                "print(\"Parallelized:\", parallelized, unit)\n",
                "#| CHECK: speedup\n",
                "print(\"Parallel speedup:\", vectorized / parallelized)"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Mojo (nightly)",
            "language": "mojo",
            "name": "mojo-nightly-jupyter-kernel"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "mojo"
            },
            "file_extension": ".mojo",
            "mimetype": "text/x-mojo",
            "name": "mojo"
        },
        "orig_nbformat": 4
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
